{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'a': 0,\n",
       " 'b': 1,\n",
       " 'c': 2,\n",
       " 'd': 3,\n",
       " 'e': 4,\n",
       " 'f': 5,\n",
       " 'g': 6,\n",
       " 'h': 7,\n",
       " 'i': 8,\n",
       " 'j': 9,\n",
       " 'k': 10,\n",
       " 'l': 11,\n",
       " 'm': 12,\n",
       " 'n': 13,\n",
       " 'o': 14,\n",
       " 'p': 15,\n",
       " 'q': 16,\n",
       " 'r': 17,\n",
       " 's': 18,\n",
       " 't': 19,\n",
       " 'u': 20,\n",
       " 'v': 21,\n",
       " 'w': 22,\n",
       " 'x': 23,\n",
       " 'y': 24,\n",
       " 'z': 25}"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import string\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Fix the RNG seed so the random weights / ids drawn in later cells are reproducible\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Vocabulary: map each lowercase letter to an integer id (a -> 0, ..., z -> 25)\n",
    "char2indx = {s: i for i, s in enumerate(string.ascii_lowercase)}\n",
    "char2indx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['l', 'o', 'v', 'e']"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Split the sample word into a list of single characters\n",
    "example = [ch for ch in 'love']\n",
    "example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([11, 14, 21,  4]), torch.Size([4]))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Numericalise the text: look up each character's id in the vocabulary\n",
    "idx = torch.tensor([char2indx[ch] for ch in example])\n",
    "idx, idx.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 1., 0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0., 0., 0.]]),\n",
       " torch.Size([4, 26]))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One-hot encode the ids, turning the text into a 2-D tensor of shape (len(text), vocab size)\n",
    "num_claz = len(char2indx)  # vocabulary size (26), derived rather than hard-coded\n",
    "dims = 5                   # embedding dimension used by the cells below\n",
    "x = F.one_hot(idx, num_classes=num_claz).float()\n",
    "x, x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-1.7867, -1.8944,  0.1891, -3.3317,  0.4883],\n",
       "        [-1.3727,  1.1942,  0.1609, -1.8016,  0.3551],\n",
       "        [ 0.0374,  0.9542,  0.1898, -0.4440,  1.4332],\n",
       "        [-1.0798,  0.7559,  0.9129,  0.4616, -0.2050]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# A text embedding is just a matrix product: each one-hot row selects one row of W\n",
    "assert x.shape == (len(idx), num_claz)  # (4, 26)\n",
    "W = torch.randn((num_claz, dims))       # (26, 5)\n",
    "x @ W                                   # (4, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-1.7867, -1.8944,  0.1891, -3.3317,  0.4883],\n",
       "        [-1.3727,  1.1942,  0.1609, -1.8016,  0.3551],\n",
       "        [ 0.0374,  0.9542,  0.1898, -0.4440,  1.4332],\n",
       "        [-1.0798,  0.7559,  0.9129,  0.4616, -0.2050]])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same result as the matrix product above, but a friendlier implementation:\n",
    "# indexing uses idx directly (fewer dimensions) and needs no one-hot encoding\n",
    "assert idx.shape == (len(example),)   # (4,)\n",
    "assert W.shape == (num_claz, dims)    # (26, 5)\n",
    "assert torch.allclose(W[idx], x @ W)  # both are (4, 5) and agree\n",
    "W[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal example implementation of a character embedding layer\n",
    "class Embedding:\n",
    "    \"\"\"Lookup-table embedding: maps integer ids to rows of a random weight matrix.\n",
    "\n",
    "    Args:\n",
    "        num_embeddings: number of rows in the table (vocabulary size).\n",
    "        embedding_dim: length of each embedding vector.\n",
    "        generator: optional torch.Generator for reproducible initialisation;\n",
    "            None (the default) uses the global RNG.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_embeddings, embedding_dim, generator=None):\n",
    "        # Weights are drawn from a standard normal distribution\n",
    "        self.weight = torch.randn((num_embeddings, embedding_dim), generator=generator)\n",
    "\n",
    "    def __call__(self, idx):\n",
    "        \"\"\"Return self.weight[idx], whose shape is idx.shape + (embedding_dim,).\n",
    "\n",
    "        idx must hold integer ids in [0, num_embeddings). The values are used\n",
    "        as row indices, so a one-hot tensor is NOT a valid input.\n",
    "        \"\"\"\n",
    "        # The result is also stored on self.out\n",
    "        self.out = self.weight[idx]\n",
    "        return self.out\n",
    "\n",
    "    def parameters(self):\n",
    "        \"\"\"Return the layer's parameter tensors (just the weight table).\"\"\"\n",
    "        return [self.weight]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[-0.9700,  1.7496, -1.6055,  0.6170, -0.3594],\n",
       "         [-1.3329, -0.3346,  0.6670, -0.2516,  0.6160],\n",
       "         [-0.9252,  0.7330,  0.0849, -0.2643,  0.1934],\n",
       "         [-0.2149, -0.4215,  1.2895, -0.6259,  0.9605]]),\n",
       " torch.Size([4]),\n",
       " torch.Size([4, 5]))"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Correct usage: pass integer ids, not one-hot vectors\n",
    "emb = Embedding(num_claz, dims)\n",
    "emb_out = emb(idx)  # look up once and reuse\n",
    "emb_out, idx.shape, emb_out.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([10, 11]), torch.Size([10, 11, 5]))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# A higher-dimensional example: bidx can be read as 10 texts,\n",
    "# each 11 units (letters) long\n",
    "bidx = torch.randint(low=0, high=num_claz, size=(10, 11))\n",
    "bidx.shape, emb(bidx).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [-0.4473,  1.5996,  1.8102, -1.1696,  0.2618],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100]],\n",
       "\n",
       "        [[ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [-0.4473,  1.5996,  1.8102, -1.1696,  0.2618],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100]],\n",
       "\n",
       "        [[ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [-0.4473,  1.5996,  1.8102, -1.1696,  0.2618],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100]],\n",
       "\n",
       "        [[ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [-0.4473,  1.5996,  1.8102, -1.1696,  0.2618],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100],\n",
       "         [ 0.5768,  0.0849, -1.4448, -1.1311,  0.3100]]])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Wrong usage: x is the one-hot encoding, so its 0/1 entries are treated as row ids;\n",
    "# the result has shape (4, 26, 5) and only ever contains rows 0 and 1 of emb.weight\n",
    "emb(x.to(torch.int32))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
